// Copyright 2022 Huawei Cloud Computing Technology Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cerrno>
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <sys/socket.h>
#include <linux/tcp.h>
#include <unistd.h>
#include <endian.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <securec.h>
#include "CasTcpSocket.h"
#include "../cas_common/CasLog.h"

using namespace std;

void CasClientNotice::Notice(void *socket)
{
    CasSocket *casSocket = (CasSocket *)socket;
    if (socket == nullptr) {
        ERR("Socket is null.");
        return;
    }
    CasEventNotice *reconnectNotice = casSocket->GetReconnectNotice();
    if (reconnectNotice == nullptr) {
        ERR("Reconnect notice is null.");
        return;
    }
    reconnectNotice->Notice(socket);
}

// CasClientNotice owns no resources of its own; nothing to release.
CasClientNotice::~CasClientNotice() = default;

int CasTcpClientSocket::CasCreateTcpClient(uint32_t socketOption, uint32_t remoteIp, uint16_t remotePort,
    uint32_t localIp, uint16_t localPort)
{
    this->m_remoteIp = remoteIp;
    this->m_remotePort = remotePort;
    this->m_localIp = localIp;
    this->m_localPort = localPort;
    this->m_socketOption = socketOption;
    this->msg_valid_map.clear();
    this->msg_no_map.clear();

    int flag = this->CreateSocket();
    if (flag == -1) {
        ERR("Failed to create socket.");
        return flag;
    }

    this->m_eventNotice = new CasClientNotice();
    return 0;
}

// Convenience constructor: identical to the socket-option overload with TCP_NODELAY and
// TCP_QUICKACK both requested.
CasTcpClientSocket::CasTcpClientSocket(uint32_t remoteIp, uint16_t remotePort, uint32_t localIp, uint16_t localPort)
    : CasTcpClientSocket(SOCKET_OPTION_BITSET_NO_DELAY | SOCKET_OPTION_BITSET_QUICK_ACK, remoteIp, remotePort,
        localIp, localPort)
{
}

// Construct a client socket with an explicit SOCKET_OPTION_BITSET_* mask.
//
// A constructor cannot return a status, so the result of CasCreateTcpClient() is discarded.
// NOTE(review): callers presumably detect a failed socket creation some other way
// (e.g. a later Connect() failure). Confirm.
CasTcpClientSocket::CasTcpClientSocket(uint32_t socketOption, uint32_t remoteIp, uint16_t remotePort, uint32_t localIp,
    uint16_t localPort)
{
    static_cast<void>(CasCreateTcpClient(socketOption, remoteIp, remotePort, localIp, localPort));
}

// Mark the socket as exiting, tear down the connection (if any) and free the event notice.
CasTcpClientSocket::~CasTcpClientSocket()
{
    m_status = SOCKET_STATUS_EXIT;

    if (m_fd != -1) {
        // Stop traffic in both directions before releasing the descriptor.
        shutdown(m_fd, SHUT_RDWR);
        close(m_fd);
    }
    m_fd = -1;

    // delete on a null pointer is a no-op, so no separate null check is needed.
    delete m_eventNotice;
    m_eventNotice = nullptr;
}

int CasTcpClientSocket::CreateSocket()
{
    int socketFd;
    struct linger tcpLinger;
    struct timeval timeVal;
    int reuse = 1;
    int flag = 1;

    m_status = SOCKET_STATUS_INIT;
    socketFd = socket(PF_INET, SOCK_STREAM, 0);
    if (socketFd < 0) {
        ERR("Create socket failed errno (%d): %s.", errno, strerror(errno));
        return -1;
    }

    if (setsockopt(socketFd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
        close(socketFd);
        ERR("Set socket opt REUSER failed errno (%d): %s.", errno, strerror(errno));
        return -1;
    }

    tcpLinger.l_onoff = 1;
    tcpLinger.l_linger = 0;

    if (setsockopt(socketFd, SOL_SOCKET, SO_LINGER, &tcpLinger, sizeof(tcpLinger)) < 0) {
        close(socketFd);
        ERR("Set socket opt LINGER failed errno (%d): %s.", errno, strerror(errno));
        return -1;
    }

    timeVal.tv_sec = 1;
    timeVal.tv_usec = 0;

    if (setsockopt(socketFd, SOL_SOCKET, SO_SNDTIMEO, &timeVal, sizeof(timeVal)) < 0) {
        close(socketFd);
        ERR("Set socket opt SNDTIMEO failed errno (%d): %s.", errno, strerror(errno));
        return -1;
    }

    if (setsockopt(socketFd, SOL_SOCKET, SO_RCVTIMEO, &timeVal, sizeof(timeVal)) < 0) {
        close(socketFd);
        ERR("Set socket opt RCVTIMEO failed errno (%d): %s.", errno, strerror(errno));
        return -1;
    }

    if (m_socketOption & SOCKET_OPTION_BITSET_NO_DELAY) {
        if (setsockopt(socketFd, IPPROTO_TCP, TCP_NODELAY, (const char *)&flag, sizeof(flag)) < 0) {
            close(socketFd);
            ERR("Set socket opt TCP_NODELAY failed errno (%d): %s.", errno, strerror(errno));
            return -1;
        }
    }
    if (m_socketOption & SOCKET_OPTION_BITSET_QUICK_ACK) {
        if (setsockopt(socketFd, IPPROTO_TCP, TCP_QUICKACK, (const char *)&flag, sizeof(flag)) < 0) {
            close(socketFd);
            ERR("Set socket opt TCP_QUICKACK failed errno (%d): %s.", errno, strerror(errno));
            return -1;
        }
    }
    this->m_fd = socketFd;
    return this->m_fd;
}

// Connect m_fd to the configured remote endpoint (m_remoteIp:m_remotePort, stored in host
// byte order and converted with htonl/htons here).
//
// Status transitions:
//   - m_status is set to SOCKET_STATUS_INIT before connecting;
//   - it becomes SOCKET_STATUS_RUNNING once connect() succeeds;
//   - with USE_TLS, it is reset to SOCKET_STATUS_INIT if ConfigSSL() fails.
//
// Returns:
//   0  on success;
//   -1 if connect() fails, or (USE_TLS) if SSL configuration fails;
//   -2 (USE_TLS only) when ConfigSSL() reports SSL_CONFIG_SOCKET_CLOSE.
int CasTcpClientSocket::Connect()
{
    struct sockaddr_in remoteAddr;
    memset_s(&remoteAddr, sizeof(remoteAddr), 0, sizeof(remoteAddr));
    remoteAddr.sin_family = AF_INET;
    remoteAddr.sin_addr.s_addr = htonl(m_remoteIp);
    remoteAddr.sin_port = htons(m_remotePort);

    // NOTE(review): m_localIp/m_localPort are not applied here. No bind() is performed, so
    // the kernel picks the local endpoint. Confirm whether binding to them was intended.

    // Text form of the server address, used only for diagnostics. It is zero-initialised so it
    // is always a valid C string, even if inet_ntop() fails.
    char serverIpAddr[INET_ADDRSTRLEN] = {0};
    if (inet_ntop(AF_INET, &remoteAddr.sin_addr, serverIpAddr, static_cast<socklen_t>(sizeof(serverIpAddr))) ==
        nullptr) {
        ERR("Failed to get server ip.");
        serverIpAddr[0] = '\0';
    }

    SetStatus(SOCKET_STATUS_INIT);
    int connectRes = ::connect(m_fd, reinterpret_cast<struct sockaddr *>(&remoteAddr), sizeof(remoteAddr));
    if (connectRes != 0) {
        // Capture errno immediately: the logging call may overwrite it.
        int connectErrno = errno;
        ERR("Connect to %s:%u failed res %d errno (%d): %s.", serverIpAddr, static_cast<unsigned int>(m_remotePort),
            connectRes, connectErrno, strerror(connectErrno));
        return -1;
    }
    SetStatus(SOCKET_STATUS_RUNNING);

#if USE_TLS
    int configRes = ConfigSSL();
    if (configRes != SSL_CONFIG_SUCCESS) {
        ERR("Failed to config SSL res %d", configRes);
        SetStatus(SOCKET_STATUS_INIT);
        if (configRes == SSL_CONFIG_SOCKET_CLOSE) {
            return -2;
        } else {
            return -1;
        }
    }
#endif
    return connectRes;
}

int CasTcpClientSocket::Reconnect()
{
    if (this->m_fd != -1) {
        shutdown(m_fd, SHUT_RDWR);
        close(m_fd);
        this->m_fd = -1;
    }

    int flag = this->CreateSocket();
    if (flag == -1) {
        ERR("Failed to create socket.");
        return flag;
    }

    return this->Connect();
}
